#include "residue_math.h"
#include "pkg.h"
#include "util.h"
#include "my_exceptions.h"
#include "aes.h"
#include "sha1.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <vector>

// Default constructor. No key material is created here; the master keys are
// filled in later (e.g. by setup(), or set_msk() followed by calc_mpk()).
Pkg::Pkg()
{
}
void Pkg::setup(const char * msk_file_name, const char * mpk_file_name,int security)
{
    	ZZ a = gen_strong_prime(security);

	ZZ b =  gen_strong_prime(security);

	set_msk(a, b);
	save_msk(msk_file_name);
	calc_mpk();
	save_mpk(mpk_file_name);

}

// std::string convenience overload of setup(const char *, const char *, int).
//   msk_file_name - receives the master secret key
//   mpk_file_name - receives the master public key
//   security      - forwarded to the const char * overload
// `security` is forwarded explicitly. Calling the two-argument form silently
// discarded the caller's value. If the default argument lives only on this
// overload, the two-argument form would even recurse back into this function.
void Pkg::setup(string msk_file_name, string mpk_file_name, int security )
{
	setup(msk_file_name.c_str(), mpk_file_name.c_str(), security);
}

// Installs the master secret key as the prime pair (p, q).
// mpk is NOT updated here; call calc_mpk() afterwards if it is needed.
// The assignment order below is kept deliberately in case p/q alias members.
void Pkg::set_msk(ZZ & p, ZZ & q)
{
	this->msk_1 = p;
	this->msk_2 = q;
}

void Pkg::calc_mpk()
{
	mpk = msk_1 * msk_2;
}

void Pkg::load_msk(const char * msk_file_name)
{
	std::ifstream msk_file(msk_file_name);
	if (! msk_file)
	{
		std::cout << "File "<< msk_file_name << " doesn't exist\n" ;
		throw;
	}
	msk_file >> msk_1 >> msk_2;
	msk_file.close();
}

// std::string convenience overload of load_msk(const char *).
void Pkg::load_msk(string msk_file_name)
{
	const char * path = msk_file_name.c_str();
	load_msk(path);
}

// Writes the master secret key as "msk_1 msk_2" (single space separated), the
// layout load_msk() reads back with operator>>.
// NOTE: open/write failures are not reported.
void Pkg::save_msk(const char * msk_file_name)
{
	std::ofstream out(msk_file_name);
	out << msk_1 << " " << msk_2;
	// `out` is flushed and closed when it goes out of scope.
}

// Writes the master public key to mpk_file_name using ZZ's operator<<.
// NOTE: open/write failures are not reported.
void Pkg::save_mpk(const char * mpk_file_name)
{
	std::ofstream out(mpk_file_name);
	out << mpk;
	// `out` is flushed and closed when it goes out of scope.
}

/*void Pkg::save_sk(const char * sk_file_name, ZZ & sk)
{
	std::ofstream sk_file(sk_file_name);
	sk_file << sk;
	sk_file.close();
}


void Pkg::save_sk(string sk_file_name, ZZ & sk)
{
	save_sk(sk_file_name.c_str(), sk);
}*/


// Returns a copy of the master public key.
ZZ Pkg::get_mpk()
{
	return this->mpk;
}

// Returns the master secret key as a Pair built from (msk_1, msk_2).
Pair Pkg::get_msk()
{
	Pair secret(msk_1, msk_2);
	return secret;
}


// Computes the private key for a hashed identity pk as
//   pow(pk, (mpk + 5 - msk_1 - msk_2) / 8, mpk)
// where pow() presumably is modular exponentiation (see residue_math.h).
// NOTE(review): this looks like Cocks-IBE key extraction, which relies on
// preconditions on the primes and on pk; confirm them in gen_strong_prime()
// and hash().
ZZ Pkg::keyextract(const ZZ & pk)
{
	// pk is used as given; it is not reduced mod mpk here.
	ZZ exponent = (mpk + 5 - msk_1 - msk_2) / 8;
	return pow(pk, exponent, mpk);
}
// Hashes `identity` with hash(..., mpk), extracts the matching private key
// and writes it to sk_file_name using ZZ's operator<<.
// NOTE: open/write failures on sk_file_name are not reported.
void Pkg::keyextract_from_string(string identity, string sk_file_name)
{
	ZZ sk = keyextract(hash(identity.c_str(), mpk));

	std::ofstream out(sk_file_name.c_str());
	out << sk;
	// `out` is flushed and closed when it goes out of scope.
}

// Reads an identity (the first whitespace-delimited token) from pk_file_name
// and writes its plain-text private key to sk_file_name.
// Throws FileNotFoundException if pk_file_name cannot be opened.
void Pkg::keyextract(string pk_file_name, string  sk_file_name)
{
	std::ifstream pk_file(pk_file_name.c_str());
	if (!pk_file)
		throw FileNotFoundException(pk_file_name);

	std::string identity;
	pk_file >> identity;
	pk_file.close();

	// Hash, extract and save: shared with the string-based entry point.
	keyextract_from_string(identity, sk_file_name);
}
// Encrypts the secret key `sk` under `password` and writes the ciphertext to
// sk_file.
//
// Key derivation: SHA-1(password). The leading KEY_LENGTH bytes of the raw
// digest words form the AES key (KEY_LENGTH * NUMBER_OF_BITS key bits).
//
// Plaintext layout, padded with rand() bytes up to a multiple of BLOCK_LENGTH:
//   [ long sk_length ][ sk_length bytes produced by BytesFromZZ(sk) ]
// Each BLOCK_LENGTH block is encrypted on its own with aes_encrypt() and
// written with fwrite().
//
// NOTE(review): an unsalted single SHA-1 as password KDF and independent
// per-block encryption (ECB-style) are weak. They are kept because changing
// them alters the on-disk format.
// NOTE: the layout depends on sizeof(long) and host byte order. fwrite's
// result is not checked here; callers may inspect ferror(sk_file).
void Pkg:: encrypt_sk(FILE * sk_file, string password, ZZ & sk)
{
	SHA1 sha;
	sha.Reset();
	unsigned int out[5];
	sha << password.c_str();
	sha.Result(out);

	// The digest is only sizeof(out) bytes; never read past it. Key bytes
	// beyond the digest (only possible if KEY_LENGTH > sizeof(out)) stay zero.
	unsigned char key[KEY_LENGTH];
	memset(key, 0, sizeof(key));
	const size_t digest_bytes =
		sizeof(key) < sizeof(out) ? sizeof(key) : sizeof(out);
	memcpy(key, out, digest_bytes);

	// A single context on the stack. The previous heap array was never freed.
	aes_context ctx;
	aes_set_key(&ctx, key, KEY_LENGTH * NUMBER_OF_BITS);

	const long sk_length = sk.size() * 8;
	const size_t length = sizeof(long) + sk_length;
	size_t number_of_blocks = length / BLOCK_LENGTH;
	if (length % BLOCK_LENGTH) ++number_of_blocks;
	const size_t padded_length = number_of_blocks * BLOCK_LENGTH;

	// Size the buffer from the key itself. A fixed 256-byte buffer overflowed
	// once the length prefix plus key bytes exceeded 256 bytes.
	std::vector<unsigned char> sk_as_chars(padded_length);
	unsigned char * plain = &sk_as_chars[0];
	memcpy(plain, &sk_length, sizeof(long));
	BytesFromZZ(plain + sizeof(long), sk, sk_length);
	for (size_t i = length; i != padded_length; ++i)
	{
		sk_as_chars[i] = static_cast<unsigned char>(rand());
	}

	unsigned char encrypted_sk_block[BLOCK_LENGTH];
	for (size_t i = 0; i != number_of_blocks; ++i)
	{
		aes_encrypt(&ctx, plain + i * BLOCK_LENGTH, encrypted_sk_block);
		fwrite(encrypted_sk_block, sizeof(char), BLOCK_LENGTH, sk_file);
	}
}
// Reads an identity (the first whitespace-delimited token) from pk_file_name,
// extracts its private key and writes it to sk_file_name, encrypted under
// `password` via encrypt_sk().
// Throws FileNotFoundException if pk_file_name cannot be opened or
// sk_file_name cannot be created.
void Pkg::keyextract(string pk_file_name, string  sk_file_name, string password)
{
	std::ifstream pk_file(pk_file_name.c_str());
	if (! pk_file)
	{
		throw FileNotFoundException(pk_file_name);
	}
	std::string identity;
	pk_file >> identity;
	pk_file.close();

	ZZ pk = hash(identity.c_str(), mpk);
	ZZ sk = keyextract(pk);

	// The key file is an output, so open it for binary writing. Mode "r"
	// made every fwrite() in encrypt_sk() fail and required the file to
	// already exist.
	FILE * sk_file = fopen(sk_file_name.c_str(), "wb");
	if (sk_file == NULL)
	{
		// Same exception type as before so existing callers still catch it.
		throw FileNotFoundException(sk_file_name);
	}
	try
	{
		encrypt_sk(sk_file, password, sk);
	}
	catch (...)
	{
		fclose(sk_file);  // do not leak the handle if encryption throws
		throw;
	}
	fclose(sk_file);
}

